from typing import *


class Solution:
    """Solver that reorders the diagonals of a square integer matrix."""

    def sortMatrix(self, grid: List[List[int]]) -> List[List[int]]:
        """Sort each top-left-to-bottom-right diagonal of ``grid`` in place.

        Diagonals that start in the first column (the main diagonal and every
        diagonal below it) are rewritten so values are non-increasing from
        top to bottom. Diagonals that start in the first row to the right of
        the main diagonal are rewritten so values are non-decreasing.

        Args:
            grid: An n x n matrix. It is mutated in place.

        Returns:
            The same ``grid`` object after its diagonals have been sorted.
        """
        size = len(grid)

        def rearrange(cells: List[Tuple[int, int]], descending: bool) -> None:
            # Walking the opposite-direction sort backwards reproduces the
            # "sort, then pop from the end" ordering exactly, ties included.
            ordered = reversed(
                sorted((grid[r][c] for r, c in cells), reverse=not descending)
            )
            for (r, c), value in zip(cells, ordered):
                grid[r][c] = value

        # Main diagonal and everything below it: largest value at the top.
        for first_row in range(size):
            span = size - first_row
            rearrange([(first_row + d, d) for d in range(span)], descending=True)

        # Strictly above the main diagonal: smallest value at the top.
        for first_col in range(1, size):
            span = size - first_col
            rearrange([(d, first_col + d) for d in range(span)], descending=False)

        return grid


if __name__ == "__main__":
    # Demo on a sample 3 x 3 matrix; expected output:
    # [[8, 2, 3], [9, 6, 7], [4, 5, 1]]
    # Guarded so that importing this module does not run the demo.
    s = Solution()
    print(s.sortMatrix([[1, 7, 3], [9, 8, 2], [4, 5, 6]]))
